{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial\n",
    "\n",
    "In this notebook, we will see how to pass your own encoder and decoder architectures to your VAE model using pythae!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the library\n",
    "%pip install pythae"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision.datasets as datasets\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the MNIST training split (downloaded to ../data if it is not there yet).\n",
    "mnist_trainset = datasets.MNIST(\n",
    "    root='../data',\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=None,\n",
    ")\n",
    "\n",
    "# Keep only the first `n_samples` images, add a channel axis and divide by 255.\n",
    "n_samples = 100\n",
    "dataset = mnist_trainset.data[:n_samples].reshape(-1, 1, 28, 28) / 255."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the first 20 images of `dataset` on a 2 x 10 grid.\n",
    "fig, axes = plt.subplots(2, 10, figsize=(10, 2))\n",
    "\n",
    "# `axes.flat` visits the grid row by row, so panel k shows image k.\n",
    "for k, ax in enumerate(axes.flat):\n",
    "    ax.matshow(dataset[k].reshape(28, 28), cmap='gray')\n",
    "    ax.axis('off')\n",
    "\n",
    "plt.tight_layout(pad=0.8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's build a custom auto-encoding architecture!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### First thing, you need to import the ``BaseEncoder`` and ``BaseDecoder`` as well as ``ModelOutput`` classes from pythae by running"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.models.nn import BaseEncoder, BaseDecoder\n",
    "from pythae.models.base.base_utils import ModelOutput"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Then build your own architectures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class Encoder_Conv_VAE_MNIST(BaseEncoder):\n",
    "    \"\"\"Convolutional encoder returning the latent embedding and log-covariance.\n",
    "\n",
    "    Four stride-2 convolutions (each followed by batch norm and ReLU) shrink a\n",
    "    28 x 28 input to a 1 x 1 map with 1024 channels (28 -> 14 -> 7 -> 3 -> 1).\n",
    "    Two linear heads then map these 1024 features to ``args.latent_dim`` values each.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, args):\n",
    "        \"\"\"Build the layers; ``args`` must expose a ``latent_dim`` attribute.\"\"\"\n",
    "        BaseEncoder.__init__(self)\n",
    "\n",
    "        self.input_dim = (1, 28, 28)\n",
    "        self.latent_dim = args.latent_dim\n",
    "        self.n_channels = 1\n",
    "\n",
    "        # Channel widths of the successive convolution stages.\n",
    "        widths = [self.n_channels, 128, 256, 512, 1024]\n",
    "        stages = []\n",
    "        for in_ch, out_ch in zip(widths[:-1], widths[1:]):\n",
    "            stages.append(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))\n",
    "            stages.append(nn.BatchNorm2d(out_ch))\n",
    "            stages.append(nn.ReLU())\n",
    "        self.conv_layers = nn.Sequential(*stages)\n",
    "\n",
    "        # Two heads on the flattened features: embedding and log-covariance.\n",
    "        self.embedding = nn.Linear(1024, args.latent_dim)\n",
    "        self.log_var = nn.Linear(1024, args.latent_dim)\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        \"\"\"Encode the batch ``x`` into a ``ModelOutput`` with ``embedding`` and ``log_covariance``.\"\"\"\n",
    "        flat_features = self.conv_layers(x).reshape(x.shape[0], -1)\n",
    "        return ModelOutput(\n",
    "            embedding=self.embedding(flat_features),\n",
    "            log_covariance=self.log_var(flat_features),\n",
    "        )\n",
    "\n",
    "\n",
    "class Decoder_Conv_AE_MNIST(BaseDecoder):\n",
    "    \"\"\"Transposed-convolution decoder mapping a latent code to a 1 x 28 x 28 output.\n",
    "\n",
    "    A linear layer expands the code to a 1024 x 4 x 4 feature map, which three\n",
    "    stride-2 transposed convolutions upsample (4 -> 7 -> 14 -> 28). The final\n",
    "    sigmoid keeps the output values in (0, 1).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, args):\n",
    "        \"\"\"Build the layers; ``args`` must expose a ``latent_dim`` attribute.\"\"\"\n",
    "        BaseDecoder.__init__(self)\n",
    "        self.input_dim = (1, 28, 28)\n",
    "        self.latent_dim = args.latent_dim\n",
    "        self.n_channels = 1\n",
    "\n",
    "        self.fc = nn.Linear(args.latent_dim, 1024 * 4 * 4)\n",
    "        self.deconv_layers = nn.Sequential(\n",
    "            # 4 x 4 -> 7 x 7 (no output padding on this stage)\n",
    "            nn.ConvTranspose2d(1024, 512, kernel_size=3, stride=2, padding=1),\n",
    "            nn.BatchNorm2d(512),\n",
    "            nn.ReLU(),\n",
    "            # 7 x 7 -> 14 x 14\n",
    "            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.BatchNorm2d(256),\n",
    "            nn.ReLU(),\n",
    "            # 14 x 14 -> 28 x 28\n",
    "            nn.ConvTranspose2d(256, self.n_channels, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.Sigmoid(),\n",
    "        )\n",
    "\n",
    "    def forward(self, z: torch.Tensor):\n",
    "        \"\"\"Decode the batch of latent codes ``z`` into a ``ModelOutput`` with ``reconstruction``.\"\"\"\n",
    "        feature_map = self.fc(z).reshape(z.shape[0], 1024, 4, 4)\n",
    "        return ModelOutput(reconstruction=self.deconv_layers(feature_map))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define a model configuration (in which the latent will be stated). Here, we use the VAE model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.models import VAEConfig\n",
    "\n",
    "model_config = VAEConfig(\n",
    "    input_dim=(1, 28, 28),\n",
    "    latent_dim=16\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build your encoder and decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both networks are built from the same configuration object.\n",
    "encoder = Encoder_Conv_VAE_MNIST(model_config)\n",
    "decoder = Decoder_Conv_AE_MNIST(model_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Last but not least, build your VAE model by passing the ``encoder`` and ``decoder`` arguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.models import VAE\n",
    "\n",
    "model = VAE(\n",
    "    model_config=model_config,\n",
    "    encoder=encoder,\n",
    "    decoder=decoder\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Now you can see that the model you've just built contains the custom encoder and decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A bare expression on the last line of a cell is displayed as the cell output.\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### *note*: If you want to launch a training of such a model, try to ensure that the provided architectures are suited for the data. pythae performs a model sanity check before launching training and raises an error if the model cannot encode and decode an input data point"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the model !"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.trainers import BaseTrainerConfig\n",
    "from pythae.pipelines import TrainingPipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build the training pipeline with your ``BaseTrainerConfig`` instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_config = BaseTrainerConfig(\n",
    "    output_dir='my_model_with_custom_archi',\n",
    "    learning_rate=1e-3,\n",
    "    per_device_train_batch_size=64,\n",
    "    per_device_eval_batch_size=64,\n",
    "    steps_saving=None,\n",
    "    num_epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tie the model and the training options together.\n",
    "pipeline = TrainingPipeline(model=model, training_config=training_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Launch the ``Pipeline``"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calling the pipeline starts the training on `dataset`.\n",
    "pipeline(train_data=dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### *note 1*: You will now see that an ``encoder.pkl`` and a ``decoder.pkl`` file appear in the folder ``my_model_with_custom_archi/training_YYYY_MM_DD_hh_mm_ss/final_model`` to allow model rebuilding with your own architectures ``Encoder_Conv_VAE_MNIST`` and ``Decoder_Conv_AE_MNIST``.\n",
    "\n",
    "### *note 2*: Model rebuilding is based on the [dill](https://pypi.org/project/dill/) library, which allows reloading the classes without importing them. Hence, you should still be able to reload the model even if the classes ``Encoder_Conv_VAE_MNIST`` or ``Decoder_Conv_AE_MNIST`` were not imported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "last_training = sorted(os.listdir('my_model_with_custom_archi'))[-1]\n",
    "print(last_training)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### You can now reload the model easily using the classmethod ``AutoModel.load_from_folder``"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.models import AutoModel\n",
    "\n",
    "model_rec = AutoModel.load_from_folder(os.path.join('my_model_with_custom_archi', last_training, 'final_model'))\n",
    "model_rec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The model can now be used to generate new samples !"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.samplers import NormalSampler\n",
    "\n",
    "\n",
    "# Draw 25 new samples using the reloaded model.\n",
    "sampler = NormalSampler(model=model_rec)\n",
    "gen_data = sampler.sample(num_samples=25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Show the 25 generated samples on a 5 x 5 grid.\n",
    "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
    "\n",
    "# `axes.flat` visits the grid row by row, so panel k shows sample k.\n",
    "for k, ax in enumerate(axes.flat):\n",
    "    ax.imshow(gen_data[k].cpu().reshape(28, 28), cmap='gray')\n",
    "    ax.axis('off')\n",
    "plt.tight_layout(pad=0.)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "95022f601a219c6b6d093149c9a9b9a061a4446d3680d89cef8a1f82970031f2"
  },
  "kernelspec": {
   "display_name": "Python 3.8.11 64-bit ('pythae_dev': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
